# gemma.cpp is a lightweight, standalone C++ inference engine for the Gemma
# foundation models from Google.

load("@rules_license//rules:license.bzl", "license")

# Package-wide defaults: unless a target overrides them, every target in this
# file is covered by the ":license" declaration and is visible to all packages.
package(
    default_applicable_licenses = [
        "//:license",  # Placeholder comment, do not modify
    ],
    # Placeholder for internal compatible_with
    default_visibility = ["//visibility:public"],
)

# License declaration (rules_license) referenced by default_applicable_licenses
# in the package() call of this file.
license(
    name = "license",
    package_name = "gemma_cpp",
)

# The project is dual-licensed under Apache 2 and 3-clause BSD; both are
# declared here under the "notice" license type.
licenses(["notice"])

# Make the LICENSE file referenceable as a label from other packages.
exports_files(["LICENSE"])

# Header-only target exposing util/basics.h; depends only on Highway.
cc_library(
    name = "basics",
    hdrs = ["util/basics.h"],
    deps = [
        "@highway//:hwy",
    ],
)

# Threading utilities (util/threading.{h,cc}), built on top of Highway's
# thread_pool and topology libraries.
cc_library(
    name = "threading",
    srcs = ["util/threading.cc"],
    hdrs = ["util/threading.h"],
    deps = [
        ":basics",
        # Placeholder for container detection, do not remove
        "@highway//:hwy",
        "@highway//:thread_pool",
        "@highway//:topology",
    ],
)

# Allocator (util/allocator.{h,cc}); depends on :threading and on Highway's
# thread pool.
cc_library(
    name = "allocator",
    srcs = ["util/allocator.cc"],
    hdrs = ["util/allocator.h"],
    deps = [
        ":basics",
        ":threading",
        "@highway//:hwy",
        "@highway//:thread_pool",
    ],
)

# Header-only test helpers (util/test_util.h) layered on Highway's test
# utilities and stats library.
# NOTE(review): this target is not marked testonly, so non-test targets are
# allowed to depend on it - confirm that is intended.
cc_library(
    name = "test_util",
    hdrs = ["util/test_util.h"],
    deps = [
        "@highway//:hwy",
        "@highway//:hwy_test_util",
        "@highway//:stats",
    ],
)

# GoogleTest-based test for :threading. It carries no "ops_tests" tag, so it
# is not part of the :ops_tests suite.
cc_test(
    name = "threading_test",
    srcs = ["util/threading_test.cc"],
    deps = [
        ":threading",
        "@googletest//:gtest_main",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
        "@highway//:nanobenchmark",
        "@highway//:thread_pool",
    ],
)

# Groups every test in this file tagged "ops_tests" so that they can all be
# built and run with a single command. No explicit `tests` list is given, so
# membership is determined by the tag filter below.
test_suite(
    name = "ops_tests",
    tags = ["ops_tests"],
)

# Numerical ops library: matmul.cc is the only compiled source. The *-inl.h
# files are listed as textual_hdrs, i.e. they are only ever #included by other
# sources and are not compiled on their own.
cc_library(
    name = "ops",
    srcs = [
        "ops/matmul.cc",
    ],
    hdrs = [
        "ops/matmul.h",
        "ops/ops.h",
    ],
    textual_hdrs = [
        "ops/dot-inl.h",
        "ops/sum-inl.h",
        "ops/fp_arith-inl.h",
        "ops/matmul-inl.h",
        "ops/matvec-inl.h",
        "ops/ops-inl.h",
    ],
    deps = [
        ":allocator",
        ":basics",
        ":threading",
        "//compression:compress",
        "@highway//:algo",
        "@highway//:bit_set",
        "@highway//:hwy",
        "@highway//:math",
        "@highway//:matvec",
        "@highway//:nanobenchmark",
        "@highway//:profiler",
        "@highway//:thread_pool",
        "@highway//:topology",
        "@highway//hwy/contrib/sort:vqsort",
    ],
)

# Test for the dot-product kernels (ops/dot_test.cc). Declared "small" but
# given a "long" timeout; built with HWY_IS_TEST defined for this target only.
cc_test(
    name = "dot_test",
    size = "small",
    timeout = "long",
    srcs = ["ops/dot_test.cc"],
    local_defines = ["HWY_IS_TEST"],
    # Membership tag for the :ops_tests suite.
    tags = ["ops_tests"],
    deps = [
        ":allocator",
        ":app",
        ":ops",
        ":test_util",
        ":threading",
        "@googletest//:gtest_main",  # buildcleaner: keep
        "//compression:compress",
        "//compression:test_util",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
        "@highway//:nanobenchmark",  #buildcleaner: keep
        "@highway//:profiler",
        "@highway//:stats",
        "@highway//:thread_pool",
    ],
)

# Test for the ops library (ops/ops_test.cc). Declared "small" but given an
# "eternal" timeout; built with HWY_IS_TEST defined for this target only.
cc_test(
    name = "ops_test",
    size = "small",
    timeout = "eternal",
    srcs = ["ops/ops_test.cc"],
    local_defines = ["HWY_IS_TEST"],
    # Membership tag for the :ops_tests suite.
    tags = ["ops_tests"],
    deps = [
        ":allocator",
        ":app",
        ":common",
        ":ops",
        ":test_util",
        "@googletest//:gtest_main",  # buildcleaner: keep
        "//compression:compress",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
        "@highway//:nanobenchmark",  #buildcleaner: keep
    ],
)

# Matrix-vector test (ops/gemma_matvec_test.cc); part of :ops_tests.
# HWY_IS_TEST is defined only while compiling this target (local_defines).
cc_test(
    name = "gemma_matvec_test",
    size = "small",
    timeout = "long",
    srcs = ["ops/gemma_matvec_test.cc"],
    local_defines = ["HWY_IS_TEST"],
    # for test_suite.
    tags = ["ops_tests"],
    deps = [
        ":ops",
        "@googletest//:gtest_main",  # buildcleaner: keep
        "//compression:compress",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
        "@highway//:thread_pool",
    ],
)

# Matrix-multiplication test (ops/matmul_test.cc); part of :ops_tests.
# HWY_IS_TEST is defined only while compiling this target (local_defines).
cc_test(
    name = "matmul_test",
    size = "small",
    timeout = "long",
    srcs = ["ops/matmul_test.cc"],
    local_defines = ["HWY_IS_TEST"],
    # for test_suite.
    tags = ["ops_tests"],
    deps = [
        ":allocator",
        ":basics",
        ":ops",
        ":threading",
        "@googletest//:gtest_main",  # buildcleaner: keep
        "//compression:compress",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
        "@highway//:thread_pool",
    ],
)

# Matmul benchmark (ops/bench_matmul.cc) declared as a cc_test. The "manual"
# tag keeps it out of wildcard target patterns, so it only runs when requested
# explicitly.
# NOTE(review): Bazel documents that a test_suite without an explicit `tests`
# list skips targets tagged "manual", so the "ops_tests" tag below may not
# actually add this target to :ops_tests - confirm.
cc_test(
    name = "bench_matmul",
    size = "small",
    timeout = "long",
    srcs = ["ops/bench_matmul.cc"],
    local_defines = ["HWY_IS_TEST"],
    tags = [
        "manual",
        "notap",
        "ops_tests",  # for test_suite.
    ],
    deps = [
        ":allocator",
        ":basics",
        ":ops",
        ":threading",
        "@googletest//:gtest_main",  # buildcleaner: keep
        "//compression:compress",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
        "@highway//:nanobenchmark",
        "@highway//:profiler",
        "@highway//:thread_pool",
    ],
)

# Shared model definitions: gemma/common, gemma/configs and gemma/tensor_index
# (each as a .cc/.h pair).
cc_library(
    name = "common",
    srcs = [
        "gemma/common.cc",
        "gemma/configs.cc",
        "gemma/tensor_index.cc",
    ],
    hdrs = [
        "gemma/common.h",
        "gemma/configs.h",
        "gemma/tensor_index.h",
    ],
    deps = [
        "//compression:fields",
        "//compression:sfp",
        "@highway//:hwy",  # base.h
        "@highway//:thread_pool",
    ],
)

# Test for gemma/configs, which is provided by :common.
cc_test(
    name = "configs_test",
    srcs = ["gemma/configs_test.cc"],
    deps = [
        ":common",
        "@googletest//:gtest_main",
        "@highway//:hwy",
    ],
)

# Test for gemma/tensor_index (provided by :common); also depends on :weights.
cc_test(
    name = "tensor_index_test",
    srcs = ["gemma/tensor_index_test.cc"],
    deps = [
        ":basics",
        ":common",
        ":weights",
        "@googletest//:gtest_main",
        "//compression:compress",
        "@highway//:hwy",
    ],
)

# Model weights (gemma/weights.{h,cc}); depends on the blob_store, compress
# and io targets of //compression.
cc_library(
    name = "weights",
    srcs = ["gemma/weights.cc"],
    hdrs = ["gemma/weights.h"],
    deps = [
        ":common",
        "//compression:blob_store",
        "//compression:compress",
        "//compression:io",
        "@highway//:hwy",
        "@highway//:profiler",
        "@highway//:stats",
        "@highway//:thread_pool",
    ],
)

# Tokenizer (gemma/tokenizer.{h,cc}); depends on the SentencePiece processor.
cc_library(
    name = "tokenizer",
    srcs = ["gemma/tokenizer.cc"],
    hdrs = ["gemma/tokenizer.h"],
    deps = [
        ":common",
        "//compression:io",
        "//compression:sfp",
        "@highway//:hwy",
        "@highway//:profiler",
        "@com_google_sentencepiece//:sentencepiece_processor",
    ],
)

# KV cache (gemma/kv_cache.{h,cc}).
cc_library(
    name = "kv_cache",
    srcs = ["gemma/kv_cache.cc"],
    hdrs = ["gemma/kv_cache.h"],
    deps = [
        ":common",
        "@highway//:hwy",
    ],
)

# Core inference library: gemma.cc plus one instantiation source per weight
# type (bf16, f32, nuq, sfp). gemma-inl.h is a textual header, i.e. it is only
# ever #included by other sources and is not compiled on its own.
cc_library(
    name = "gemma_lib",
    srcs = [
        "gemma/gemma.cc",
        "gemma/instantiations/bf16.cc",
        "gemma/instantiations/f32.cc",
        "gemma/instantiations/nuq.cc",
        "gemma/instantiations/sfp.cc",
    ],
    hdrs = [
        "gemma/activations.h",
        "gemma/gemma.h",
    ],
    exec_properties = {
        # Request extra memory so that sanitizer-instrumented links do not OOM.
        "mem": "28g",
    },
    textual_hdrs = [
        "gemma/gemma-inl.h",
        # Placeholder for internal file2, do not remove,
    ],
    deps = [
        ":allocator",
        ":basics",
        ":common",
        ":kv_cache",
        ":ops",
        ":threading",
        ":tokenizer",
        ":weights",
        "//compression:compress",
        "//compression:io",
        "//compression:sfp",
        "//paligemma:image",
        "@highway//:bit_set",
        "@highway//:hwy",
        "@highway//:nanobenchmark",  # timer
        "@highway//:profiler",
        "@highway//:thread_pool",
    ],
)

# Cross-entropy evaluation (evals/cross_entropy.{h,cc}) on top of :gemma_lib.
cc_library(
    name = "cross_entropy",
    srcs = ["evals/cross_entropy.cc"],
    hdrs = ["evals/cross_entropy.h"],
    deps = [
        ":common",
        ":gemma_lib",
        ":ops",
        "@highway//:hwy",
    ],
)

# Header-only target exposing util/args.h.
cc_library(
    name = "args",
    hdrs = ["util/args.h"],
    deps = [
        ":basics",
        "//compression:io",
        "@highway//:hwy",
    ],
)

# Header-only target exposing util/app.h; depends on :args and :gemma_lib.
cc_library(
    name = "app",
    hdrs = ["util/app.h"],
    deps = [
        ":args",
        ":basics",
        ":common",
        ":gemma_lib",
        ":ops",
        ":threading",
        "//compression:io",
        "//compression:sfp",
        "@highway//:hwy",
    ],
)

# Helper library (evals/benchmark_helper.{h,cc}) shared by the tests and
# binaries below that list it in deps.
cc_library(
    name = "benchmark_helper",
    srcs = ["evals/benchmark_helper.cc"],
    hdrs = ["evals/benchmark_helper.h"],
    deps = [
        ":app",
        ":args",
        ":common",
        ":cross_entropy",
        ":gemma_lib",
        ":kv_cache",
        ":threading",
        # Placeholder for internal dep, do not remove.,
        "@google_benchmark//:benchmark",
        "//compression:compress",
        "@highway//:hwy",
        "@highway//:nanobenchmark",
        "@highway//:topology",
    ],
)

# End-to-end test (evals/gemma_test.cc), linked statically. It needs model
# files at runtime, hence the "manual" tag: it is skipped by wildcard target
# patterns and only runs when named explicitly.
cc_test(
    name = "gemma_test",
    srcs = ["evals/gemma_test.cc"],
    linkstatic = True,
    # Requires model files
    tags = [
        "local",
        "manual",
        "no_tap",
    ],
    deps = [
        ":benchmark_helper",
        ":common",
        ":gemma_lib",
        "@googletest//:gtest_main",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
    ],
)

# Batch benchmark (evals/gemma_batch_bench.cc) declared as a cc_test. It needs
# model files at runtime, hence the "manual" tag: it is skipped by wildcard
# target patterns and only runs when named explicitly.
cc_test(
    name = "gemma_batch_bench",
    srcs = ["evals/gemma_batch_bench.cc"],
    # Requires model files
    tags = [
        "local",
        "manual",
        "no_tap",
    ],
    deps = [
        ":benchmark_helper",
        ":common",
        ":gemma_lib",
        ":tokenizer",
        "@googletest//:gtest_main",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
    ],
)

# The `gemma` executable, built from gemma/run.cc.
cc_binary(
    name = "gemma",
    srcs = ["gemma/run.cc"],
    deps = [
        ":app",
        ":args",
        ":benchmark_helper",
        ":common",
        ":gemma_lib",
        ":threading",
        # Placeholder for internal dep, do not remove.,
        "//compression:sfp",
        "//paligemma:image",
        "@highway//:hwy",
        "@highway//:profiler",
        "@highway//:thread_pool",
    ],
)

# Benchmark executable built from evals/benchmark.cc; depends on
# nlohmann_json.
cc_binary(
    name = "single_benchmark",
    srcs = ["evals/benchmark.cc"],
    deps = [
        ":args",
        ":benchmark_helper",
        ":common",
        ":cross_entropy",
        ":gemma_lib",
        "//compression:io",
        "@highway//:hwy",
        "@highway//:nanobenchmark",
        "@nlohmann_json//:json",
    ],
)

# Header-only target exposing evals/prompts.h.
cc_library(
    name = "benchmark_prompts",
    hdrs = ["evals/prompts.h"],
    deps = ["@highway//:hwy"],
)

# Benchmark executable built from evals/benchmarks.cc and linked against
# google_benchmark.
# evals/prompts.h is provided by :benchmark_prompts (listed in its hdrs), so
# it is not repeated in srcs here; that keeps the header owned by exactly one
# target.
cc_binary(
    name = "benchmarks",
    srcs = ["evals/benchmarks.cc"],
    deps = [
        ":benchmark_helper",
        ":benchmark_prompts",
        "@google_benchmark//:benchmark",
        "@highway//:hwy",  # base.h
    ],
)

# Executable built from evals/debug_prompt.cc; depends on nlohmann_json.
cc_binary(
    name = "debug_prompt",
    srcs = ["evals/debug_prompt.cc"],
    deps = [
        ":args",
        ":benchmark_helper",
        ":gemma_lib",
        "//compression:io",
        "@highway//:hwy",
        "@nlohmann_json//:json",
    ],
)

# Executable built from evals/run_mmlu.cc; depends on nlohmann_json.
cc_binary(
    name = "gemma_mmlu",
    srcs = ["evals/run_mmlu.cc"],
    deps = [
        ":args",
        ":benchmark_helper",
        ":gemma_lib",
        "//compression:io",
        "@highway//:hwy",
        "@highway//:profiler",
        "@highway//:thread_pool",
        "@nlohmann_json//:json",
    ],
)

# Header-only target exposing backprop/prompt.h; it has no dependencies.
cc_library(
    name = "prompt",
    hdrs = ["backprop/prompt.h"],
)

# Header-only target exposing backprop/sampler.h; depends only on :prompt.
cc_library(
    name = "sampler",
    hdrs = ["backprop/sampler.h"],
    deps = [":prompt"],
)

# Forward and backward pass library (backprop/forward.cc, backprop/backward.cc).
# The *-inl.h files are textual headers, i.e. they are only ever #included by
# other sources and are not compiled on their own.
cc_library(
    name = "backprop",
    srcs = [
        "backprop/backward.cc",
        "backprop/forward.cc",
    ],
    hdrs = [
        "backprop/activations.h",
        "backprop/backward.h",
        "backprop/forward.h",
    ],
    textual_hdrs = [
        "backprop/backward-inl.h",
        "backprop/forward-inl.h",
    ],
    deps = [
        ":allocator",
        ":common",
        ":ops",
        ":prompt",
        ":weights",
        "//compression:compress",
        "@highway//:dot",
        "@highway//:hwy",  # base.h
        "@highway//:thread_pool",
    ],
)

# Header-only target exposing the backprop/*_scalar.h headers.
# NOTE(review): backprop/activations.h is also listed in the hdrs of
# :backprop, so that header is exported by two targets - confirm this is
# intended.
cc_library(
    name = "backprop_scalar",
    hdrs = [
        "backprop/activations.h",
        "backprop/backward_scalar.h",
        "backprop/common_scalar.h",
        "backprop/forward_scalar.h",
    ],
    deps = [
        ":common",
        ":prompt",
        ":weights",
        "//compression:compress",
        "@highway//:hwy",
    ],
)

# Test for :backprop_scalar. backprop/test_util.h is compiled in through srcs
# instead of coming from a library target.
cc_test(
    name = "backward_scalar_test",
    size = "large",
    srcs = [
        "backprop/backward_scalar_test.cc",
        "backprop/test_util.h",
    ],
    deps = [
        ":backprop_scalar",
        ":common",
        ":prompt",
        ":sampler",
        ":weights",
        "@googletest//:gtest_main",
        "//compression:compress",
        "@highway//:thread_pool",
    ],
)

# Test that depends on both :backprop and :backprop_scalar.
# backprop/test_util.h is compiled in through srcs instead of coming from a
# library target.
cc_test(
    name = "backward_test",
    size = "large",
    srcs = [
        "backprop/backward_test.cc",
        "backprop/test_util.h",
    ],
    exec_properties = {
        # Request extra memory so that sanitizer-instrumented links do not OOM.
        "mem": "28g",
    },
    deps = [
        ":allocator",
        ":backprop",
        ":backprop_scalar",
        ":common",
        ":gemma_lib",
        ":ops",
        ":prompt",
        ":sampler",
        ":threading",
        ":weights",
        "@googletest//:gtest_main",
        "//compression:compress",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
        "@highway//:thread_pool",
    ],
)

# Optimizer (backprop/optimizer.{h,cc}).
cc_library(
    name = "optimizer",
    srcs = ["backprop/optimizer.cc"],
    hdrs = ["backprop/optimizer.h"],
    deps = [
        ":allocator",
        ":common",
        ":weights",
        "//compression:compress",
        "@highway//:hwy",
        "@highway//:thread_pool",
    ],
)

# Test built from backprop/optimize_test.cc; depends on :backprop, :optimizer
# and :gemma_lib.
cc_test(
    name = "optimize_test",
    srcs = ["backprop/optimize_test.cc"],
    exec_properties = {
        # Request extra memory so that sanitizer-instrumented links do not OOM.
        "mem": "28g",
    },
    deps = [
        ":allocator",
        ":backprop",
        ":basics",
        ":common",
        ":gemma_lib",
        ":ops",
        ":optimizer",
        ":prompt",
        ":sampler",
        ":threading",
        ":weights",
        "@googletest//:gtest_main",
        "//compression:sfp",
        "@highway//:thread_pool",
    ],
)
